{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Question Answering System\n",
    "In this example we will be going over the code used to build a question answering system. This example uses a modified BERT model to extract features from questions and Milvus to search for similar questions and answers. \n",
    "\n",
    "> This project is based on Milvus 2.0.0-rc5."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "This example uses the [InsuranceQA Corpus](https://github.com/shuzi/insuranceQA) dataset, which contains 27,413 answers totaling 3,065,492 running words.\n",
    "\n",
    "Download location: https://github.com/chatopera/insuranceqa-corpus-zh/tree/release/corpus/pairs\n",
    "\n",
    "In this example, we use a small subset of the dataset that contains 100 question-answer pairs. It can be found under the **data** directory."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Requirements\n",
    "\n",
    "\n",
    "| Packages                   | Servers           |\n",
    "| -                          | -                 |\n",
    "| pymilvus==2.0.0rc5         | milvus 2.0.0-rc5  |\n",
    "| sentence_transformers      |                   |\n",
    "| pymysql                    | mysql             |\n",
    "| pandas                     |                   |\n",
    "| numpy                      |                   |\n",
    "\n",
    "We have included a `requirements.txt` file that installs the required packages.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Up and Running"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installing Packages\n",
    "Install the required Python packages with `requirements.txt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Starting Milvus Server\n",
    "\n",
    "This demo uses Milvus 2.0.0. Please refer to the [Install Milvus](https://milvus.io/docs/v2.0.0/install_standalone-docker.md) guide to learn how to run Milvus with Docker. For this example we won't be mapping any local volumes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! wget https://raw.githubusercontent.com/milvus-io/milvus/master/deployments/docker/standalone/docker-compose.yml -O docker-compose.yml\n",
    "! docker-compose up -d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Starting MySQL Server\n",
    "For now, Milvus doesn't support storing string data, so we need a relational database to store the questions and answers. In this example, we use MySQL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! docker run -p 3306:3306 -e MYSQL_ROOT_PASSWORD=123456 -d --name qa_mysql mysql:5.7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Confirm Running Servers\n",
    "If Milvus Standalone boots successfully, three running docker containers appear (two infrastructure services and one Milvus service):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! docker-compose ps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! docker logs qa_mysql --tail 6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Code Overview"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Connecting to Servers\n",
    "We first start off by connecting to the servers. In this case the docker containers are running on localhost and the ports are the default ports. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connect to Milvus and MySQL\n",
    "from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility\n",
    "import pymysql\n",
    "\n",
    "connections.connect(host='localhost', port='19530')\n",
    "conn = pymysql.connect(host='localhost', user='root', port=3306, password='123456', database='mysql',local_infile=True)\n",
    "cursor = conn.cursor()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating Collection and Setting Index\n",
    "#### 1. Creating the Collection  \n",
    "A collection in Milvus is similar to a table in a relational database, and is used for storing all the vectors.  \n",
    "Creating a collection requires the following:  \n",
    "- `name`: the name of the collection.  \n",
    "- `schema`: the fields of the collection. Here we use an auto-generated INT64 primary key `id` and a FLOAT_VECTOR field `embedding` with `dim=768`, because the model generates 768-dimensional vectors.  \n",
    "\n",
    "The `metric_type`, the distance formula used to calculate similarity, is set when the index is created. In this example we use inner product (IP)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TABLE_NAME = 'question_answering'\n",
    "\n",
    "# Drop any previously stored collection for a clean run\n",
    "if utility.has_collection(TABLE_NAME):\n",
    "    collection = Collection(name=TABLE_NAME)\n",
    "    collection.drop()\n",
    "\n",
    "field1 = FieldSchema(name=\"id\", dtype=DataType.INT64, description=\"int64\", is_primary=True, auto_id=True)\n",
    "field2 = FieldSchema(name=\"embedding\", dtype=DataType.FLOAT_VECTOR, description=\"float vector\", dim=768, is_primary=False)\n",
    "schema = CollectionSchema(fields=[field1, field2], description=\"collection description\")\n",
    "collection = Collection(name=TABLE_NAME, schema=schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. Setting an Index\n",
    "After creating the collection we want to assign it an index type. This can be done before or after inserting the data. When done before, indexes are built as data comes in and fills the data segments. In this example we use IVF_FLAT, which requires the `nlist` parameter. Each index type has its own parameters. More info about these parameters can be found [here](https://milvus.io/api-reference/pymilvus-orm/v2.0.0rc1/param.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "default_index = {\"index_type\": \"IVF_FLAT\", \"metric_type\": 'IP', \"params\": {\"nlist\": 200}}\n",
    "collection.create_index(field_name=\"embedding\", index_params=default_index)"
   ]
  },
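  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build some intuition for `nlist` (and for the `nprobe` search parameter used later), here is a rough NumPy sketch of how an inverted-file index narrows a search. It is not Milvus's implementation: the helper names `build_ivf` and `ivf_search` are hypothetical, and the centroids are simply sampled from the data, whereas a real IVF index would typically train them with k-means.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def build_ivf(vectors, nlist, seed=0):\n",
    "    # Simplification: sample nlist data points as centroids instead of running k-means.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    centroids = vectors[rng.choice(len(vectors), size=nlist, replace=False)]\n",
    "    # Assign every vector to the centroid with the largest inner product.\n",
    "    assignments = np.argmax(vectors @ centroids.T, axis=1)\n",
    "    buckets = [np.where(assignments == c)[0] for c in range(nlist)]\n",
    "    return centroids, buckets\n",
    "\n",
    "def ivf_search(query, vectors, centroids, buckets, nprobe, limit):\n",
    "    # Scan only the nprobe buckets whose centroids score highest against the query.\n",
    "    probe = np.argsort(-(centroids @ query))[:nprobe]\n",
    "    candidates = np.concatenate([buckets[c] for c in probe])\n",
    "    scores = vectors[candidates] @ query\n",
    "    order = np.argsort(-scores)[:limit]\n",
    "    return candidates[order].tolist(), scores[order].tolist()\n",
    "\n",
    "# Toy data: 1000 random unit vectors in 16 dimensions.\n",
    "rng = np.random.default_rng(42)\n",
    "vectors = rng.normal(size=(1000, 16))\n",
    "vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)\n",
    "\n",
    "centroids, buckets = build_ivf(vectors, nlist=20)\n",
    "query = vectors[0]\n",
    "# Probing 4 of 20 buckets is approximate; probing all 20 is an exhaustive scan.\n",
    "approx_ids, approx_scores = ivf_search(query, vectors, centroids, buckets, nprobe=4, limit=5)\n",
    "exact_ids, exact_scores = ivf_search(query, vectors, centroids, buckets, nprobe=20, limit=5)\n",
    "```\n",
    "\n",
    "A larger `nlist` makes each bucket smaller, and a larger `nprobe` scans more buckets per query, trading speed for recall."
   ]
  },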
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating Table in MySQL  \n",
    "MySQL will be used to store the Milvus ID and its corresponding question-answer combo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop any previously stored table for a clean run\n",
    "drop_table = \"DROP TABLE IF EXISTS \" + TABLE_NAME + \";\"\n",
    "cursor.execute(drop_table)\n",
    "\n",
    "try:\n",
    "    sql = \"CREATE TABLE if not exists \" + TABLE_NAME + \" (id TEXT, question TEXT, answer TEXT);\"\n",
    "    cursor.execute(sql)\n",
    "    print(\"create MySQL table successfully!\")\n",
    "except Exception as e:\n",
    "    print(\"can't create a MySQL table: \", e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Processing and Storing QA Dataset\n",
    "#### 1. Generating Embeddings\n",
    "In this example we use the `sentence_transformers` library to encode sentences into vectors. This library uses a modified BERT model to generate the embeddings, and here we use `paraphrase-mpnet-base-v2`, a pretrained model based on Microsoft's MPNet. More info can be found [here](https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "# NOTE: these are local paths; point them at your own copies of the model and the example CSV.\n",
    "model = SentenceTransformer('/mnt/usersuccess/lym/nlp/model/paraphrase-mpnet-base-v2')\n",
    "\n",
    "# Get questions and answers.\n",
    "data = pd.read_csv('/mnt/usersuccess/lym/example.csv')\n",
    "question_data = data['question'].tolist()\n",
    "answer_data = data['answer'].tolist()\n",
    "\n",
    "# Encode the questions, then L2-normalize so that inner product equals cosine similarity.\n",
    "sentence_embeddings = model.encode(question_data)\n",
    "sentence_embeddings = normalize(sentence_embeddings).tolist()"
   ]
  },
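  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The embeddings are L2-normalized before insertion because the collection is searched with the inner product (IP) metric: for unit-length vectors, the inner product is the cosine similarity. The small NumPy sketch below illustrates this on toy vectors; `l2_normalize` and `cosine_similarity` are hypothetical helpers written for this illustration, not part of the notebook's pipeline.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def l2_normalize(matrix):\n",
    "    # Scale every row to unit length (the same effect as sklearn's normalize with its defaults).\n",
    "    norms = np.linalg.norm(matrix, axis=1, keepdims=True)\n",
    "    return matrix / norms\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    # Cosine of the angle between two raw (unnormalized) vectors.\n",
    "    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
    "\n",
    "raw = np.array([[3.0, 4.0], [1.0, 0.0]])\n",
    "unit = l2_normalize(raw)\n",
    "\n",
    "# Inner product of the normalized rows matches the cosine similarity of the raw rows.\n",
    "inner_product = float(unit[0] @ unit[1])\n",
    "cosine = cosine_similarity(raw[0], raw[1])\n",
    "```\n",
    "\n",
    "This is why the similarity scores returned later fall between -1 and 1, with higher values meaning closer matches."
   ]
  },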
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. Inserting Vectors into Milvus\n",
    "Since this example dataset contains only 100 vectors, we are inserting all of them as one batch insert."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insert all embeddings in one batch; Milvus generates the primary keys because auto_id=True.\n",
    "mr = collection.insert([sentence_embeddings])\n",
    "ids = mr.primary_keys\n",
    "print(len(ids))"
   ]
  },
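  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A single insert is fine for 100 vectors. For a much larger dataset you would probably want to insert in chunks instead. The sketch below shows one way to split a list into batches; `iter_batches` is a hypothetical helper and the batch size is an arbitrary assumption, not a Milvus recommendation.\n",
    "\n",
    "```python\n",
    "def iter_batches(items, batch_size):\n",
    "    # Yield consecutive slices of at most batch_size items.\n",
    "    for start in range(0, len(items), batch_size):\n",
    "        yield items[start:start + batch_size]\n",
    "\n",
    "# Hypothetical usage with this notebook's objects (needs a running Milvus, so it is left commented out):\n",
    "# all_ids = []\n",
    "# for batch in iter_batches(sentence_embeddings, 10000):\n",
    "#     all_ids.extend(collection.insert([batch]).primary_keys)\n",
    "\n",
    "batches = list(iter_batches(list(range(10)), 4))\n",
    "print(batches)\n",
    "```\n",
    "\n",
    "Collecting the primary keys batch by batch keeps them in the same order as the questions, which the MySQL step below relies on."
   ]
  },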
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. Inserting IDs and Question-Answer Combos into MySQL\n",
    "To transfer the data into MySQL, we combine each Milvus ID with its question and answer into a tuple. We then insert all the tuples into the MySQL table with a single `executemany` call and commit the transaction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "\n",
    "# Combine the id of the vector and the question data into a list\n",
    "def format_data(ids, question_data, answer_data):\n",
    "    data = []\n",
    "    for i in range(len(ids)):\n",
    "        value = (str(ids[i]), question_data[i], answer_data[i])\n",
    "        data.append(value)\n",
    "    return data\n",
    "\n",
    "def load_data_to_mysql(cursor, conn, table_name, data):\n",
    "    sql = \"insert into \" + table_name + \" (id,question,answer) values (%s,%s,%s);\"\n",
    "    try:\n",
    "        cursor.executemany(sql, data)\n",
    "        conn.commit()\n",
    "        print(\"MYSQL loads data to table: {} successfully\".format(table_name))\n",
    "    except Exception as e:\n",
    "        print(\"MYSQL ERROR: {} with sql: {}\".format(e, sql))\n",
    "        \n",
    "load_data_to_mysql(cursor, conn, TABLE_NAME, format_data(ids, question_data, answer_data))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Search\n",
    "#### 1. Processing Query\n",
    "When searching for a question, we first put the question through the same model to generate an embedding. Then, with that embedding vector, we can search for similar embeddings in Milvus.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"What is AAA?\"\n",
    "\n",
    "# Encode the query with the same model, then L2-normalize it like the stored vectors.\n",
    "embed = model.encode(question)\n",
    "embed = embed.reshape(1, -1)\n",
    "embed = normalize(embed)\n",
    "query_embeddings = embed.tolist()\n",
    "\n",
    "# Load the collection into memory before searching.\n",
    "collection.load()\n",
    "\n",
    "# nprobe is the number of IVF clusters scanned per query.\n",
    "search_params = {\"metric_type\": 'IP', \"params\": {\"nprobe\": 16}}\n",
    "\n",
    "results = collection.search(query_embeddings, anns_field=\"embedding\", param=search_params, limit=5)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. Getting the Similar Questions\n",
    "There may be no questions similar to the given one. So we can set a threshold value, here 0.5: when the similarity score of the closest match is below this value, we can return a hint that the system does not contain a relevant question. We then use the result IDs to pull the similar questions from the MySQL server and print them with their corresponding similarity scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ids = [str(x.id) for x in results[0]]\n",
    "\n",
    "def search_by_milvus_ids(cursor, ids, table_name):\n",
    "    str_ids = str(ids).replace('[', '').replace(']', '')\n",
    "    sql = \"select question from \" + table_name + \" where id in (\" + str_ids + \") order by field (id,\" + str_ids + \");\"\n",
    "    try:\n",
    "        cursor.execute(sql)\n",
    "        results = cursor.fetchall()\n",
    "        results = [res[0] for res in results]\n",
    "        return results\n",
    "    except Exception as e:\n",
    "        print(\"MYSQL ERROR: {} with sql: {}\".format(e, sql))\n",
    "\n",
    "similar_questions = search_by_milvus_ids(cursor, ids, TABLE_NAME)\n",
    "\n",
    "distances = [x.distance for x in results[0]]\n",
    "\n",
    "res = dict(zip(similar_questions, distances))\n",
    "\n",
    "print('There are similar questions in the database, here are the closest matches:\\n{}'.format(res))\n"
   ]
  },
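  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The threshold check described above could look roughly like the sketch below. `pick_matches` is a hypothetical helper, and the questions and scores are made-up values for illustration. It assumes inner-product scores on normalized vectors, where a higher score means a closer match.\n",
    "\n",
    "```python\n",
    "def pick_matches(questions, scores, threshold=0.5):\n",
    "    # Return None when even the best score is below the threshold,\n",
    "    # which signals that no relevant question is stored.\n",
    "    if not scores or max(scores) < threshold:\n",
    "        return None\n",
    "    return list(zip(questions, scores))\n",
    "\n",
    "# Made-up example values, not real search results.\n",
    "matches = pick_matches(['What is AAA?', 'What is BBB?'], [0.83, 0.41])\n",
    "no_match = pick_matches(['What is CCC?'], [0.2])\n",
    "\n",
    "if no_match is None:\n",
    "    print('No relevant question found.')\n",
    "```\n",
    "\n",
    "With this notebook's variables, the call would be `pick_matches(similar_questions, distances)`."
   ]
  },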
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. Getting the Answer\n",
    "After getting a list of similar questions, choose the one that you feel is closest to yours. Then you can use that question to find the corresponding answer in MySQL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use a parameterized query so quotes in the question cannot break the SQL.\n",
    "sql = \"select answer from \" + TABLE_NAME + \" where question = %s;\"\n",
    "\n",
    "cursor.execute(sql, (similar_questions[0],))\n",
    "rows=cursor.fetchall()\n",
    "print(\"Question:\")\n",
    "print(question)\n",
    "print(\"Answer:\")\n",
    "print(rows[0][0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv_lym",
   "language": "python",
   "name": "venv_lym"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
